/*
 * Copyright (c) 2010-2011 Michael Laudati, N1 Concepts LLC.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. The names of the authors may not be used to endorse or promote products
 * derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESSED OR IMPLIED WARRANTIES,
 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL N1
 * CONCEPTS LLC OR ANY CONTRIBUTORS TO THIS SOFTWARE BE LIABLE FOR ANY DIRECT,
 * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

package org.vngx.jsch;

import static org.vngx.jsch.constants.TransportLayerProtocol.*;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.security.MessageDigest;
import java.util.Arrays;
import org.vngx.jsch.algorithm.AlgorithmManager;
import org.vngx.jsch.algorithm.Algorithms;
import org.vngx.jsch.algorithm.Compression;
import org.vngx.jsch.algorithm.Random;
import org.vngx.jsch.cipher.Cipher;
import org.vngx.jsch.cipher.CipherManager;
import org.vngx.jsch.config.SessionConfig;
import org.vngx.jsch.exception.JSchException;
import org.vngx.jsch.hash.Hash;
import org.vngx.jsch.hash.HashManager;
import org.vngx.jsch.hash.MAC;
import org.vngx.jsch.hash.MACException;
import org.vngx.jsch.kex.KexProposal;
import org.vngx.jsch.kex.KeyExchange;

/**
 * Implementation to manage the transport layer for {@code Session} to read and
 * write SSH packets to/from the socket.
 *
 * @author Michael Laudati
 */
final class SessionIO {

	/** Minimum read size in bytes. (8 bytes) Also the minimum padding block size. */
	private static final int MIN_READ_SIZE = 8;
	/**
	 * Smallest acceptable value of the packet length field: 1 byte for the
	 * padding length plus the 4 bytes of minimum padding (RFC 4253 section 6).
	 */
	private static final int MIN_PACKET_LENGTH = 5;

	/** Session instance this IO belongs to. */
	private final Session _session;
	/** Random instance for generating the random padding in outbound packets. */
	private final Random _random;
	/** Socket input stream for session's transport layer. */
	private final InputStream _sessionIn;
	/** Socket output stream for session's transport layer. */
	private final OutputStream _sessionOut;

	/** Cipher instance for decrypting inbound data from server to client. */
	private Cipher _readCipher;
	/** Cipher instance for encrypting outbound data from client to server. */
	private Cipher _writeCipher;
	/** MAC for generating MACs to validate data sent from server to client. */
	private MAC _readMac;
	/** MAC for generating MACs to send to server from client for validation. */
	private MAC _writeMac;
	/** Deflater for compressing outbound data (when compression is used). */
	private Compression _compressor;
	/** Inflater for decompressing inbound data (when compression is used). */
	private Compression _decompressor;
	/** Current sequence number of inbound packet used for validating s2c MAC. */
	private int _inSequence = 0;
	/** Current sequence number of outbound packet used for generating c2s MAC. */
	private int _outSequence = 0;
	/** Result of server-to-client MAC generated by client to compare to server. */
	private byte[] _clientMacDigest;
	/** Result of server-to-client MAC retrieved from server packet to compare to client. */
	private byte[] _serverMacDigest;
	/** Server to client cipher IV size. */
	private int _readCipherSize = MIN_READ_SIZE;
	/** Client to server cipher IV size. */
	private int _writeCipherSize = MIN_READ_SIZE;
	/** Local buffer to retrieve uncompressed length when compression is used. */
	private final int[] _uncompressLen = new int[1];


	/**
	 * Creates a new instance of <code>SessionIO</code> for the specified
	 * <code>in</code> and <code>out</code> streams for session.
	 *
	 * @param session instance
	 * @param in stream of session socket
	 * @param out stream of session socket
	 * @throws IllegalArgumentException if any argument is null
	 * @throws JSchException if the random padding generator cannot be created
	 */
	private SessionIO(Session session, InputStream in, OutputStream out) throws JSchException {
		if( session == null ) {
			throw new IllegalArgumentException("Session cannot be null");
		} else if( in == null ) {
			throw new IllegalArgumentException("InputStream cannot be null");
		} else if( out == null ) {
			throw new IllegalArgumentException("OutputStream cannot be null");
		}
		_session = session;
		_sessionIn = in;
		_sessionOut = out;
		_random = AlgorithmManager.getManager().createAlgorithm(Algorithms.RANDOM, _session);
	}

	/**
	 * Static factory creating a new <code>SessionIO</code> for the session's
	 * socket streams.
	 *
	 * @param session instance
	 * @param in stream of session socket
	 * @param out stream of session socket
	 * @return new instance
	 * @throws JSchException if the instance cannot be created
	 */
	static SessionIO createIO(Session session, InputStream in, OutputStream out) throws JSchException {
		return new SessionIO(session, in, out);
	}

	/**
	 * Reads the next SSH packet from the session input stream into the
	 * specified buffer. The packet is decrypted, MAC-verified and decompressed
	 * according to the algorithms set by the most recent key exchange (each
	 * step is skipped while its algorithm is not set).
	 *
	 * @param buffer to read the packet into (it is reset first)
	 * @return the specified buffer, rewound and holding the packet
	 * @throws JSchException if the packet is corrupt, fails MAC verification
	 *         or cannot be decompressed
	 * @throws IOException if the input stream ends or any IO errors occur
	 */
	public Buffer read(final Buffer buffer) throws JSchException, IOException {
		// Implementations should decrypt the length after receiving the first 8
		// (or cipher block size, whichever is larger) bytes of a packet.
		final int blockSize = Math.max(MIN_READ_SIZE, _readCipherSize);
		buffer.reset();
		final int read = getByte(buffer, blockSize);	// always the full first block
		if( _readCipher != null ) {
			_readCipher.update(buffer.buffer, 0, read, buffer.buffer, 0);
		}

		// Read total length of the SSH packet to determine how much to read in
		// RFC 4253 6.1 Maximum Packet Length - Throw exception if invalid size.
		// Implementations should check that the packet length is reasonable in
		// order for the implementation to avoid denial of service and/or buffer
		// overflow attacks.
		final int packetLen = buffer.getInt();
		if( packetLen < MIN_PACKET_LENGTH || packetLen > Packet.MAX_SIZE ) {
			startDiscard(buffer, packetLen, Packet.MAX_SIZE,
					packetLen < MIN_PACKET_LENGTH ? "too small" : "too big", SSH_DISCONNECT_PROTOCOL_ERROR);
		}

		// Determine required space needed to read in remaining packet data and
		// ensure it's a multiple of the block size (otherwise it's corrupt):
		// packet length + 4 bytes (packet length field) - size already read
		final int remaining = packetLen + 4 - read;
		if( remaining % blockSize != 0 ) {
			startDiscard(buffer, packetLen, Packet.MAX_SIZE - _readCipherSize, "invalid size", SSH_DISCONNECT_PROTOCOL_ERROR);
		}

		// Read in rest of inbound packet from the input stream
		if( remaining > 0 ) {
			buffer.ensureCapacity(remaining);	// Always ensure buffer capacity
			getByte(buffer, remaining);
			if( _readCipher != null ) {
				_readCipher.update(buffer.buffer, read, remaining, buffer.buffer, read);
			}
		}

		// Generate MAC over the inbound sequence number and decrypted packet
		// and compare it to the MAC sent by the server after the packet
		if( _readMac != null ) {
			_readMac.update(_inSequence);
			_readMac.update(buffer.buffer, 0, buffer.index);
			_readMac.doFinal(_clientMacDigest, 0);
			getByte(_serverMacDigest, 0, _serverMacDigest.length);	// Read server sent MAC
			// Constant-time comparison so timing does not leak matching prefix length
			if( !MessageDigest.isEqual(_clientMacDigest, _serverMacDigest) ) {
				if( remaining > Packet.MAX_SIZE ) {
					throw new MACException("Inbound packet is corrupt: MAC verification failed");
				}
				startDiscard(buffer, packetLen, Packet.MAX_SIZE - remaining, "MAC verification failed", SSH_DISCONNECT_MAC_ERROR);
			}
		}

		_inSequence++;	// Increment number of inbound packets (required for MAC)

		// Decompress the packet payload (after the 5 header bytes) if enabled
		if( _decompressor != null ) {
			// Padding length is an unsigned byte (RFC 4253 allows up to 255),
			// so mask off sign extension
			final int paddingSize = buffer.buffer[4] & 0xff;
			_uncompressLen[0] = buffer.index - 5 - paddingSize;
			final byte[] uncompressed = _decompressor.uncompress(buffer.buffer, 5, _uncompressLen);
			if( uncompressed == null ) {
				throw new JSchException("Failed to decompress packet data", SSH_DISCONNECT_COMPRESSION_ERROR);
			}
			buffer.buffer = uncompressed;
			buffer.index = 5 + _uncompressLen[0];	// Index is set excluding padding
		}

		buffer.rewind();	// Rewind buffer to prepare for use
		return buffer;		// and return to caller
	}

	/**
	 * Handles a corrupt inbound packet; this method always completes by
	 * throwing a {@code JSchException}. If no read cipher is set, or it is not
	 * in CBC mode, the exception is thrown immediately. Otherwise the remaining
	 * suspect data is first read from the stream and dropped: up to
	 * {@code discard} bytes in total, counting what is already in the buffer.
	 * That data is also fed through the read MAC unless {@code packetLength}
	 * equals {@code Packet.MAX_SIZE}.
	 * NOTE(review): this appears to mirror JSch's mitigation of the SSH CBC
	 * plaintext-recovery attack (not failing faster based on the decrypted
	 * length) — confirm.
	 *
	 * @param buffer holding the data read so far; reused as scratch space
	 * @param packetLength value decoded from the packet length field
	 * @param discard total number of bytes to consume, including those
	 *        already in {@code buffer}
	 * @param msg detail appended to the exception message
	 * @param reasonCode SSH disconnect reason code attached to the exception
	 * @throws JSchException always
	 * @throws IOException if any IO errors occur while discarding
	 */
	private void startDiscard(Buffer buffer, int packetLength, int discard, String msg, int reasonCode) throws JSchException, IOException {
		// Without a CBC read cipher (including before the first key exchange,
		// when no cipher is set) fail immediately
		if( _readCipher == null || !_readCipher.isCBC() ) {
			throw new JSchException("Inbound packet is corrupt: "+msg, reasonCode);
		}

		// Finish reading in the rest of the data from the Session's in stream
		// which needs to be discarded.
		final MAC discardMac = packetLength != Packet.MAX_SIZE ? _readMac : null;
		int toDiscard = discard - buffer.index;
		while( toDiscard > 0 ) {
			buffer.reset();
			final int len = Math.min(toDiscard, buffer.buffer.length);
			getByte(buffer.buffer, 0, len);
			if( discardMac != null ) {
				discardMac.update(buffer.buffer, 0, len);
			}
			toDiscard -= len;
		}
		if( discardMac != null ) {
			discardMac.doFinal(buffer.buffer, 0);
		}
		throw new JSchException("Inbound packet is corrupt: "+msg, reasonCode);
	}

	/**
	 * Writes the packet to the outbound socket stream after encoding it.
	 * Encoding applies, in order, payload compression, random padding, the
	 * MAC (over the outbound sequence number and unencrypted packet) and
	 * encryption of the packet. The MAC itself is not encrypted.
	 *
	 * @param packet to send; the MAC is written after the padded packet data,
	 *        so presumably the caller reserves {@link #getWriteMacSize()}
	 *        bytes of room there — TODO confirm against Packet
	 * @throws JSchException if any encoding errors occur
	 * @throws IOException if writing to the output stream fails
	 */
	void write(final Packet packet) throws JSchException, IOException {
		// If compression is enabled, compress the buffer data excluding the
		// packet length and padding length fields (first 5 bytes of buffer)
		if( _compressor != null ) {
			packet.buffer.index = _compressor.compress(packet.buffer.buffer, 5, packet.buffer.index);
		}
		// Pad to a multiple of the cipher block size or 8, whichever is larger
		// (RFC 4253 section 6), setting the packet length and padding length
		packet.setPadding(Math.max(MIN_READ_SIZE, _writeCipherSize), _random);

		// If MAC algorithm is set, add the MAC to end of packet
		if( _writeMac != null ) {
			_writeMac.update(_outSequence);
			_writeMac.update(packet.buffer.buffer, 0, packet.buffer.index);
			_writeMac.doFinal(packet.buffer.buffer, packet.buffer.index);
		}
		// Apply the encryption to the entire packet contents, excluding MAC at end
		if( _writeCipher != null ) {
			_writeCipher.update(packet.buffer.buffer, 0, packet.buffer.index, packet.buffer.buffer, 0);
		}
		// Move index to end of MAC (MAC should not have been encrypted)
		if( _writeMac != null ) {
			packet.buffer.skip(_writeMac.getBlockSize());
		}
		put(packet);	// Send packet data to session output stream
		_outSequence++;	// Increment outbound sequence after packet's been sent
	}

	/**
	 * Reads exactly {@code length} bytes from the input stream into the
	 * specified array, blocking until they are all available.
	 *
	 * @param buffer destination array
	 * @param start offset in {@code buffer} of the first byte written
	 * @param length number of bytes to read
	 * @throws IOException if the stream ends early or any IO errors occur
	 */
	void getByte(byte[] buffer, int start, int length) throws IOException {
		int offset = start;
		int left = length;
		while( left > 0 ) {
			final int bytesRead = _sessionIn.read(buffer, offset, left);
			if( bytesRead < 0 ) {
				throw new IOException("End of Session InputStream");
			}
			offset += bytesRead;
			left -= bytesRead;
		}
	}

	/**
	 * Reads exactly {@code length} bytes from the input stream into the
	 * specified buffer at its current index, advancing the index.
	 *
	 * @param buffer destination buffer
	 * @param length number of bytes to read
	 * @return the total number of bytes read, always {@code length} (the
	 *         previous version returned only the size of the final chunk)
	 * @throws IOException if the stream ends early or any IO errors occur
	 */
	int getByte(Buffer buffer, int length) throws IOException {
		int left = length;
		while( left > 0 ) {
			final int bytesRead = _sessionIn.read(buffer.buffer, buffer.index, left);
			if( bytesRead < 0 ) {
				throw new IOException("End of Session InputStream");
			}
			buffer.skip(bytesRead);	// Update buffer's internal index
			left -= bytesRead;
		}
		return length;
	}

	/**
	 * Writes the specified packet's buffer contents (up to its index) to the
	 * output stream and flushes it.
	 *
	 * @param p packet to write
	 * @throws IOException if any IO errors occur
	 */
	void put(Packet p) throws IOException {
		_sessionOut.write(p.buffer.buffer, 0, p.buffer.index);
		_sessionOut.flush();
	}

	/**
	 * Generates new keys during key exchange and sets up the required
	 * algorithms for the session including ciphers, MACs and compression
	 * implementations. The cipher and MAC fields are replaced only after all
	 * of them have been created and initialized successfully.
	 *
	 * @param kex to use for generating key values
	 * @throws JSchException if any errors occur
	 */
	void initNewKeys(KeyExchange kex) throws JSchException {
		final KexProposal proposal = kex.getKexProposal();
		final Hash hash = kex.getKexAlgorithm().getHash();
		final byte[] H = kex.getKexAlgorithm().getH();
		final byte[] K = kex.getKexAlgorithm().getK();

		try {
			// Initial IV client to server:     HASH (K || H || "A" || session_id)
			// Initial IV server to client:     HASH (K || H || "B" || session_id)
			// Encryption key client to server: HASH (K || H || "C" || session_id)
			// Encryption key server to client: HASH (K || H || "D" || session_id)
			// Integrity key client to server:  HASH (K || H || "E" || session_id)
			// Integrity key server to client:  HASH (K || H || "F" || session_id)
			final Buffer buffer = new Buffer();
			buffer.putMPInt(K);
			buffer.putBytes(H);
			final int letterIndex = buffer.index;	// Length of the K || H prefix
			buffer.putByte((byte) 0x41);	// 0x41 = 'A'
			buffer.putBytes(_session.getSessionId());
			hash.update(buffer.buffer, 0, buffer.index);
			final byte[] c2sCipherIV = hash.digest();

			buffer.buffer[letterIndex]++;	// Increment to 'B'
			hash.update(buffer.buffer, 0, buffer.index);
			final byte[] s2cCipherIV = hash.digest();

			buffer.buffer[letterIndex]++;	// Increment to 'C'
			hash.update(buffer.buffer, 0, buffer.index);
			byte[] c2sCipherKey = hash.digest();

			buffer.buffer[letterIndex]++;	// Increment to 'D'
			hash.update(buffer.buffer, 0, buffer.index);
			byte[] s2cCipherKey = hash.digest();

			buffer.buffer[letterIndex]++;	// Increment to 'E'
			hash.update(buffer.buffer, 0, buffer.index);
			final byte[] c2sMacKey = hash.digest();

			buffer.buffer[letterIndex]++;	// Increment to 'F'
			hash.update(buffer.buffer, 0, buffer.index);
			final byte[] s2cMacKey = hash.digest();

			// Generate server-to-client cipher instance
			final Cipher readCipher = CipherManager.getManager().createCipher(proposal.getCipherAlgStoC(), _session);
			s2cCipherKey = expandKey(hash, buffer, letterIndex, s2cCipherKey, readCipher.getBlockSize());
			readCipher.init(Cipher.DECRYPT_MODE, s2cCipherKey, s2cCipherIV);

			// Generate server-to-client MAC instance
			final MAC readMac = HashManager.getManager().createMAC(proposal.getMACAlgStoC());
			readMac.init(s2cMacKey);

			// Generate client-to-server cipher instance
			final Cipher writeCipher = CipherManager.getManager().createCipher(proposal.getCipherAlgCtoS(), _session);
			c2sCipherKey = expandKey(hash, buffer, letterIndex, c2sCipherKey, writeCipher.getBlockSize());
			writeCipher.init(Cipher.ENCRYPT_MODE, c2sCipherKey, c2sCipherIV);

			// Generate client-to-server MAC instance
			final MAC writeMac = HashManager.getManager().createMAC(proposal.getMACAlgCtoS());
			writeMac.init(c2sMacKey);

			// Commit only now that every cipher and MAC initialized successfully
			_readCipher = readCipher;
			_readCipherSize = readCipher.getIVSize();
			_readMac = readMac;
			_clientMacDigest = new byte[readMac.getBlockSize()];
			_serverMacDigest = new byte[readMac.getBlockSize()];
			_writeCipher = writeCipher;
			_writeCipherSize = writeCipher.getIVSize();
			_writeMac = writeMac;

			// Generate inflater/deflater instances for compression
			initCompressor(proposal.getCompressionAlgCtoS());
			initDecompressor(proposal.getCompressionAlgStoC());
		} catch(Exception e) {
			throw new JSchException("Failed to initialize new keys", e);
		}
		kex.kexCompleted();	// No longer in key exchange
	}

	/**
	 * Extends derived key material until it is at least
	 * {@code requiredLength} bytes long, as in RFC 4253 section 7.2:
	 * K1 || K2 || ..., where K(n+1) = HASH(K || H || K1 || ... || Kn).
	 * Any exception from the hash is propagated to the caller's handler.
	 *
	 * @param hash exchange hash algorithm
	 * @param buffer whose first {@code prefixLength} bytes already hold K || H;
	 *        the bytes after the prefix are overwritten
	 * @param prefixLength length of the K || H prefix in {@code buffer}
	 * @param key initial key material (K1)
	 * @param requiredLength minimum length of the returned key
	 * @return {@code key} itself if already long enough, else the extended key
	 */
	private static byte[] expandKey(Hash hash, Buffer buffer, int prefixLength, byte[] key, int requiredLength) throws Exception {
		byte[] expanded = key;
		while( requiredLength > expanded.length ) {
			buffer.reset();
			buffer.skip(prefixLength);	// Keep the K || H prefix in place
			buffer.putBytes(expanded);
			hash.update(buffer.buffer, 0, buffer.index);
			expanded = Util.join(expanded, hash.digest());
		}
		return expanded;
	}

	/**
	 * Returns whether the specified compression method should be activated
	 * now. 'zlib' is always activated. 'zlib@openssh.com' is only activated
	 * once the session is authenticated. That method is OpenSSH's "delayed"
	 * compression, which starts after user authentication succeeds;
	 * presumably the session re-initializes compression after authentication
	 * — confirm against Session.
	 *
	 * @param method of compression
	 * @return true if a compression instance should be created for the method
	 */
	private boolean isCompressionActive(String method) {
		return Compression.COMPRESSION_ZLIB.equals(method) ||
				(_session.isAuthenticated() && Compression.COMPRESSION_ZLIB_OPENSSH.equals(method));
	}

	/**
	 * Initializes the <code>Compression</code> instance for compressing
	 * (deflating) data sent from client to server. If the compression type is
	 * 'none', the deflater instance is set to null. For any other method that
	 * is not active yet (or not recognized) the current instance is left
	 * unchanged.
	 *
	 * @param method of compression ('none', 'zlib', etc)
	 * @throws JSchException if the deflater cannot be created or initialized
	 */
	void initCompressor(String method) throws JSchException {
		if( Compression.COMPRESSION_NONE.equals(method) ) {
			_compressor = null;
		} else if( isCompressionActive(method) ) {
			try {
				final Compression compressor = AlgorithmManager.getManager().createAlgorithm(method, _session);
				compressor.init(Compression.COMPRESS_MODE, _session.getConfig().getInteger(SessionConfig.COMPRESSION_LEVEL));
				_compressor = compressor;
			} catch(Exception e) {
				throw new JSchException("Failed to initialize deflater, method: "+method, e);
			}
		}
	}

	/**
	 * Initializes the <code>Compression</code> instance for inflating
	 * compressed data sent from server to client. If the compression type is
	 * 'none', the inflater instance is set to null. For any other method that
	 * is not active yet (or not recognized) the current instance is left
	 * unchanged.
	 *
	 * @param method of compression ('none', 'zlib', etc)
	 * @throws JSchException if the inflater cannot be created or initialized
	 */
	void initDecompressor(String method) throws JSchException {
		if( Compression.COMPRESSION_NONE.equals(method) ) {
			_decompressor = null;
		} else if( isCompressionActive(method) ) {
			try {
				final Compression decompressor = AlgorithmManager.getManager().createAlgorithm(method, _session);
				decompressor.init(Compression.DECOMPRESS_MODE, 0);
				_decompressor = decompressor;
			} catch(Exception e) {
				throw new JSchException("Failed to initialize inflater, method: "+method, e);
			}
		}
	}

	/**
	 * Returns the size in bytes of the MAC appended to outbound packets.
	 *
	 * @return client-to-server MAC block size, or 0 if no MAC is set
	 */
	int getWriteMacSize() {
		return _writeMac != null ? _writeMac.getBlockSize() : 0;
	}

}
